"""Print experiment results as a LaTeX table.

Usage: python disp_experiments.py --input_dir=experiments
"""

import json
import os

from absl import app
from absl import flags
import numpy as np
from sklearn.utils.extmath import randomized_svd

# Datasets to tabulate; each needs a matching `<name>.json` under --input_dir.
DS = ['mnist', 'fashion_mnist', 'smallnorb', 'colorectal_histology']

flags.DEFINE_string(
    name='input_dir',
    default=None,
    help='Directory with measurements')
FLAGS = flags.FLAGS

# Characters with special meaning in LaTeX, mapped to their escaped forms.
# A single translate() pass applies all replacements at once, so the braces
# introduced by e.g. `\textbackslash{}` are not escaped a second time.
_LATEX_ESCAPES = str.maketrans({
    '\\': '\\textbackslash{}',
    '&': '\\&',
    '%': '\\%',
    '$': '\\$',
    '#': '\\#',
    '_': '\\_',
    '{': '\\{',
    '}': '\\}',
    '~': '\\textasciitilde{}',
    '^': '\\textasciicircum{}',
})

def format_tt(algo):
  """Wrap `algo` in `\\texttt{...}` with LaTeX special characters escaped.

  Escaping matters beyond underscores: an unescaped `&` would be parsed as a
  column separator inside the table rows this script prints.

  Args:
    algo: Plain-text name (algorithm or dataset) to typeset.

  Returns:
    The LaTeX snippet `\\texttt{<escaped algo>}`.
  """
  return '\\texttt{' + algo.translate(_LATEX_ESCAPES) + '}'

def _load_results(input_dir):
  """Load one JSON results file per dataset in `DS`.

  Args:
    input_dir: Directory expected to contain `<dataset>.json` for every
      dataset name in `DS`.

  Returns:
    Dict mapping each dataset name to the parsed contents of its JSON file.

  Raises:
    ValueError: If any of the expected files is missing.
  """
  results = {}
  for ds in DS:
    input_file = os.path.join(input_dir, f'{ds}.json')
    if not os.path.exists(input_file):
      raise ValueError(f'Path {input_file} does not exist')
    # JSON is UTF-8 by spec; don't depend on the platform's locale encoding.
    with open(input_file, 'r', encoding='utf-8') as fp:
      results[ds] = json.load(fp)
  return results

def main(argv) -> None:
  """Print a LaTeX table comparing algorithms across datasets.

  Reads `<dataset>.json` for every dataset in `DS` from --input_dir. For each
  metric ('loss', then 'time') and each algorithm, it prints one table row
  holding the value at index `rank - 1` (rank = 20) for every dataset.

  Args:
    argv: Positional command-line arguments from absl; only the program name
      is accepted.

  Raises:
    app.UsageError: If extra positional arguments are given.
    ValueError: If a dataset's results file is missing.
  """
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')
  results = _load_results(FLAGS.input_dir)

  rank = 20
  algos = ['svd_w', 'adam', 'em', 'greedy', 'sample', 'svd']
  print(' & '.join(['Algorithm'] + [format_tt(ds) for ds in DS]) + ' \\\\')
  for metric in ['loss', 'time']:
    # Row-label suffix; only the timing rows carry a unit.
    metric_print = metric + ' (s)' if metric == 'time' else metric
    print('\\hline')
    for algo in algos:
      # NOTE(review): presumably index i of each metric list holds the value
      # for rank i + 1, hence `rank - 1`; confirm against the JSON producer.
      values = [f'{results[ds][algo][metric][rank - 1]:.4f}' for ds in DS]
      print(' & '.join([f'{format_tt(algo)} {metric_print}'] + values) + ' \\\\')



if __name__ == '__main__':
  # Without --input_dir, main() would fail later with an opaque TypeError from
  # os.path.join(None, ...); fail fast at flag parsing instead.
  flags.mark_flag_as_required('input_dir')
  app.run(main)
